In [106]:
import pickle
import jax

import matplotlib.pyplot as plt
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde
import plotly.express as px
import pandas as pd
import pickle
tfd = tfp.distributions
import plotly
plotly.offline.init_notebook_mode()
In [107]:
# Common evaluation grid for every density curve: 10,000 evenly spaced theta values on [-15, 25].
THETA_MIN, THETA_MAX, N_GRID = -15, 25, 10_000
x = jnp.linspace(THETA_MIN, THETA_MAX, N_GRID)
In [108]:
# Load the fitted full-rank Ajax VI object.
# NOTE: pickle.load can execute arbitrary code — only open result files you produced.
with open('./results_data/logistic_full_rank_Ajax', mode='rb') as fh:
    variational = pickle.load(fh)
In [109]:
# Recover the VI posterior's mean and covariance from the fitted Ajax object.
# `jax.tree_leaves` was deprecated and has been removed from recent JAX releases;
# `jax.tree_util.tree_leaves` is the supported spelling.
# NOTE(review): relies on the transformed distribution flattening to exactly two
# leaves in the order (loc, scale factor) — TODO confirm against Ajax's `transform_dist`.
params = variational.get_params()
loc_m, scale_tril = jax.tree_util.tree_leaves(variational.transform_dist(params['beta']))
# Despite its name, `scale` holds S @ S.T — the covariance matrix, assuming
# `scale_tril` is a Cholesky-style factor. The name is kept because the next
# cell reads its diagonal.
scale = jnp.dot(scale_tril, scale_tril.T)
In [110]:
# VI marginals: a univariate Normal per theta coordinate, using the VI mean and
# the square root of the matching covariance diagonal entry, evaluated on `x`.
# This cell (re)initialises `all_pdf`; the Laplace and MCMC cells append to it.
all_pdf = [
    tfd.Normal(loc=loc_m[k], scale=jnp.sqrt(scale[k, k])).prob(x)
    for k in range(3)
]
In [111]:
# Load the Laplace-approximation result.
# NOTE: pickle.load can execute arbitrary code — only open result files you produced.
with open('./results_data/logistic_regression_laplace', mode='rb') as fh:
    laplace = pickle.load(fh)
In [112]:
# Mean and standard deviation(s) of the Laplace approximation.
# NOTE(review): this rebinds `loc_m`, which previously held the VI mean. Re-running
# the VI density cell after this one would silently pair Laplace means with VI
# variances — restart from the VI-loading cell if you need to redo that step.
loc_m, std = laplace.loc, laplace.stddev()
In [113]:
# Laplace marginals: one Normal density per theta coordinate, appended after the
# VI curves. Appends in place, so running this cell twice duplicates curves.
all_pdf.extend(
    tfd.Normal(loc=loc_m[k], scale=std[k]).prob(x) for k in range(3)
)
In [114]:
# Load the BlackJAX MCMC samples.
# NOTE: pickle.load can execute arbitrary code — only open result files you produced.
with open('./results_data/MCMC_Blackjax', mode='rb') as fh:
    black_samples = pickle.load(fh)
In [115]:
# MCMC marginals: smooth the draws of each theta coordinate with a Gaussian KDE
# (scalar bw_method=0.3 is used directly as the bandwidth factor) and evaluate it
# on the grid. Appends in place, so running this cell twice duplicates curves.
theta_draws = black_samples.position['theta']
for col in range(3):
    kde_black = gaussian_kde(theta_draws[:, col], bw_method=0.3)
    all_pdf.append(kde_black(x))
In [119]:
# One label per (curve, grid point): each curve name is repeated once for every
# value in `x`, in the same order the curves were appended to `all_pdf`.
curve_names = (
    'Ajax VI theta0', 'Ajax VI theta1', 'Ajax VI theta2',
    'Laplace theta0', 'Laplace theta1', 'Laplace theta2',
    'MCMC theta0', 'MCMC theta1', 'MCMC theta2',
)
all_label = [name for name in curve_names for _ in range(x.shape[0])]
In [120]:
# Flatten the per-curve densities into one long vector, curve after curve.
# NOTE: this rebinds `all_pdf` from a list to an array, so the cells that append
# to it cannot be re-run afterwards without first re-running the cell that
# initialises the list.
all_pdf = jnp.asarray(all_pdf).ravel()
In [121]:
# Assemble a long-format table (one row per grid point per curve) and plot every
# marginal density on shared axes, coloured by method/coordinate.
n_grid = x.shape[0]
# Derive the number of curves from the labels instead of hard-coding 9.
n_curves = len(all_label) // n_grid
# Fail fast with a readable message if an append cell was re-run (or skipped)
# and the densities no longer line up with their labels.
assert len(all_label) == n_curves * n_grid == len(all_pdf), (
    f"expected {n_curves} curves of {n_grid} points each, "
    f"got {len(all_pdf)} pdf values and {len(all_label)} labels"
)
x_repeated = jnp.tile(x, n_curves)
to_df = {
    "theta": x_repeated,
    "PDF": all_pdf,
    "label": all_label,
}
df = pd.DataFrame(to_df)

# Plot from the DataFrame that was just built (previously it was created but unused).
fig = px.line(df, x="theta", y="PDF", color="label", title="logistic regression")
fig.show()
fig.write_html("logistic_reg_result_plotly.html")